{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Install seaborn to colorful visualization!\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "getattr(tqdm, '_instances', {}).clear()\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams['figure.dpi']= 300\n",
    "import matplotlib.pyplot as plt\n",
    "#from IPython.display import Video, HTML\n",
    "\n",
    "# for visualization\n",
    "import cv2\n",
    "from moviepy.editor import *\n",
    "\n",
    "from src.utils import io_utils, eval_utils\n",
    "\n",
    "np.set_printoptions(precision=3, suppress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# same font size for the tick labels of both axes\n",
    "plt.rc(('xtick', 'ytick'), labelsize=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image, ImageDraw, ImageFont\n",
    "\n",
    "FAST = True\n",
    "\n",
    "def expand_mask(mask, margin=2, height=12):\n",
    "    \"\"\"Thicken a mask by repeating it and padding it with zero rows.\n",
    "\n",
    "    Args:\n",
    "        mask: array of shape (rows, W, 3); callers in this notebook pass one row.\n",
    "        margin: number of all-zero rows placed above and below the copies.\n",
    "        height: number of times `mask` is stacked along axis 0.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (2 * margin + height * rows, W, 3).\n",
    "    \"\"\"\n",
    "    w = mask.shape[1]\n",
    "    # `np.int` was removed in NumPy 1.24; the builtin `int` selects the same dtype.\n",
    "    pad_row = np.zeros((1, w, 3), dtype=int)\n",
    "    rows = [pad_row] * margin + [mask] * height + [pad_row] * margin\n",
    "    return np.concatenate(rows, axis=0)\n",
    "\n",
    "def text_phantom(text, width=480):\n",
    "    \"\"\"Render `text` as a single-line image with white glyphs on black.\n",
    "\n",
    "    Args:\n",
    "        text: string drawn from the top-left corner of the canvas.\n",
    "        width: canvas width in pixels; the canvas height is fixed to 20.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (20, width, 3) holding the color-inverted canvas.\n",
    "    \"\"\"\n",
    "    # Availability is platform dependent\n",
    "    font = 'DejaVuSans-Bold'\n",
    "\n",
    "    # Create font\n",
    "    # (the former `pil_font.getsize(text)` call was dropped: its result was never\n",
    "    # used and the method no longer exists in Pillow >= 10)\n",
    "    pil_font = ImageFont.truetype(font + \".ttf\", size=16,\n",
    "                                  encoding=\"unic\")\n",
    "\n",
    "    # create a blank white canvas that is a bit taller than the 16px font\n",
    "    canvas = Image.new('RGB', (width, 20), (255, 255, 255))\n",
    "\n",
    "    # draw the text onto the canvas in black\n",
    "    draw = ImageDraw.Draw(canvas)\n",
    "    black = \"#000000\"\n",
    "    draw.text((0,0), text, font=pil_font, fill=black)\n",
    "\n",
    "    # invert colors: (text, background) goes from (black, white) to (white, black)\n",
    "    return 255 - np.asarray(canvas)\n",
    "\n",
    "def sampling_idx(preds, policy=\"random\"):\n",
    "    \"\"\"Pick the index of one sample from `preds` according to `policy`.\n",
    "\n",
    "    Args:\n",
    "        preds: dict with parallel lists under \"gts\" and \"predictions\";\n",
    "            `preds[\"predictions\"][i][0]` is compared against `preds[\"gts\"][i]`.\n",
    "        policy: \"random\" draws uniformly from all samples; \"success\" keeps samples\n",
    "            whose tIoU is not below 0.8 and whose GT start is not below 15;\n",
    "            \"failure\" keeps samples whose tIoU is not above 0.2.\n",
    "\n",
    "    Returns:\n",
    "        Integer index into the lists stored in `preds`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `policy` is unknown or no sample satisfies it.\n",
    "    \"\"\"\n",
    "    num_samples = len(preds[\"gts\"])\n",
    "    if policy == \"random\":\n",
    "        return random.randint(0, num_samples - 1)\n",
    "    if policy not in (\"success\", \"failure\"):\n",
    "        # a misspelled policy used to behave silently like \"random\"\n",
    "        raise ValueError(f\"unknown policy: {policy!r}\")\n",
    "\n",
    "    def accept(i):\n",
    "        tiou = eval_utils.compute_tiou(preds[\"predictions\"][i][0], preds[\"gts\"][i])\n",
    "        if policy == \"success\":\n",
    "            return not (tiou < 0.8 or preds[\"gts\"][i][0] < 15)\n",
    "        return not tiou > 0.2\n",
    "\n",
    "    # Filtering up front (instead of re-drawing until a sample fits) still picks\n",
    "    # uniformly among the matching samples, but cannot loop forever when none match.\n",
    "    candidates = [i for i in range(num_samples) if accept(i)]\n",
    "    if not candidates:\n",
    "        raise ValueError(f\"no sample satisfies policy {policy!r}\")\n",
    "    return random.choice(candidates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_bar(gt, pred, vlen, wbar):\n",
    "    \"\"\"Draw a timeline bar marking the GT span in red and the prediction in blue.\n",
    "\n",
    "    Args:\n",
    "        gt: (start, end) of the ground-truth segment, in the same unit as `vlen`.\n",
    "        pred: (start, end) of the predicted segment, in the same unit as `vlen`.\n",
    "        vlen: total video length used to normalize both segments.\n",
    "        wbar: width of the bar in pixels.\n",
    "\n",
    "    Returns:\n",
    "        Float array of width `wbar` with 3 channels: the expanded GT strip\n",
    "        stacked on top of the expanded prediction strip.\n",
    "    \"\"\"\n",
    "    # map timestamps to pixel columns; the builtin `int` replaces `np.int`,\n",
    "    # which was removed in NumPy 1.24\n",
    "    gt_idx = np.round(np.asarray(gt) / vlen * wbar).astype(int)\n",
    "    pred_idx = np.round(np.asarray(pred) / vlen * wbar).astype(int)\n",
    "    gt_mask, pred_mask = np.zeros((1,wbar,3)), np.zeros((1,wbar,3))\n",
    "    gt_mask[0, gt_idx[0]:gt_idx[1], 0] = 255 # red channel\n",
    "    pred_mask[0, pred_idx[0]:pred_idx[1], 2] = 255 # blue channel\n",
    "\n",
    "    # expand masks for better visualization and concatenate them\n",
    "    bar = np.concatenate([expand_mask(gt_mask, margin=4), expand_mask(pred_mask)], axis=0)\n",
    "    return bar\n",
    "\n",
    "def _highlighted_clip(source, span, vlen, margin, color):\n",
    "    \"\"\"Frame `source` with `color` during `span` and with the default margin elsewhere.\"\"\"\n",
    "    # clamp the span into [0, vlen] so that every piece is a valid sub-clip\n",
    "    start = min(max(span[0], 0), vlen)\n",
    "    end = min(max(span[1], start), vlen)\n",
    "    pieces = [(0, start, {}), (start, end, {\"color\": color}), (end, vlen, {})]\n",
    "    # zero-length pieces hold no frames, so they are left out\n",
    "    return concatenate_videoclips([\n",
    "        source.subclip(t0, t1).margin(margin, **style)\n",
    "        for t0, t1, style in pieces if t1 > t0\n",
    "    ])\n",
    "\n",
    "def make_result_video(preds, D, dt, vid_dir, policy=\"random\", visualize=True):\n",
    "    \"\"\"Render a side-by-side GT/prediction video for one sampled query.\n",
    "\n",
    "    The left panel frames the ground-truth span in red, the right panel frames the\n",
    "    predicted span in blue; the query text is stacked above the frames and a\n",
    "    GT/prediction timeline bar below them.\n",
    "\n",
    "    Args:\n",
    "        preds: dict with parallel lists \"durations\", \"qids\", \"predictions\", \"gts\"\n",
    "            and \"vids\".\n",
    "        D: dataset object; `D.anns[qid]` must provide \"tokens\", \"video_id\" and\n",
    "            \"duration\".\n",
    "        dt: \"anet\" or \"charades\"; selects the file naming and the speed-up target.\n",
    "        vid_dir: directory prefix (including the trailing separator) of the raw videos.\n",
    "        policy: forwarded to `sampling_idx`; also names the output sub-directory.\n",
    "        visualize: if True the clip is displayed inline, otherwise it is written to\n",
    "            visualization/{dt}/{policy}/{vid}.mp4.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `dt` is neither \"anet\" nor \"charades\".\n",
    "    \"\"\"\n",
    "    # sampling index and fetching relevant information\n",
    "    idx = sampling_idx(preds, policy)\n",
    "\n",
    "    vlen = preds[\"durations\"][idx]\n",
    "    qid = preds[\"qids\"][idx]\n",
    "    pred = preds[\"predictions\"][idx][0]\n",
    "    gt = preds[\"gts\"][idx]\n",
    "    vid = preds[\"vids\"][idx]\n",
    "    query = \" \".join(D.anns[qid][\"tokens\"])\n",
    "    assert vid == D.anns[qid][\"video_id\"], \"{} != {}\".format(vid, D.anns[qid][\"video_id\"])\n",
    "    assert vlen == D.anns[qid][\"duration\"], \"{} != {}\".format(vlen, D.anns[qid][\"duration\"])\n",
    "\n",
    "    vw, mg, nw = 480, 20, 50 # video width, margin, number of characters per text line\n",
    "    if dt == \"anet\":\n",
    "        vname = vid[2:] + \".mp4\" # file name is the id without its first two characters\n",
    "    elif dt == \"charades\":\n",
    "        vname = vid + \".mp4\"\n",
    "    else:\n",
    "        raise NotImplementedError()\n",
    "    vid_path = vid_dir + vname\n",
    "    print(f\"video path: {vid_path}\")\n",
    "\n",
    "    # open the file once and release it when done (it used to be opened six\n",
    "    # times without ever being closed)\n",
    "    source = VideoFileClip(vid_path)\n",
    "    try:\n",
    "        # concatenate two videos where one for GT (red) and another for prediction (blue)\n",
    "        vid_GT = _highlighted_clip(source, gt, vlen, mg, (255, 0, 0))\n",
    "        vid_pred = _highlighted_clip(source, pred, vlen, mg, (0, 0, 255))\n",
    "        cc = clips_array([[vid_GT, vid_pred]]).resize(width=vw)\n",
    "        if FAST:\n",
    "            # aim for roughly 20s (charades) / 30s (otherwise) of playback; never go\n",
    "            # below 1 because a rounded factor of 0 is not a valid speed\n",
    "            target = 20 if dt == \"charades\" else 30\n",
    "            factor = max(1.0, float(np.round(vlen / target)))\n",
    "            print(f\"speedup factor: {factor}\")\n",
    "            cc = cc.speedx(factor=factor)\n",
    "\n",
    "        print(f\"duration  : {vlen}\")\n",
    "        print(f\"vid       : {vid}\")\n",
    "        print(f\"Q         : {query}\")\n",
    "        print(f\"prediction: {pred}\")\n",
    "        print(f\"gt.       : {gt}\")\n",
    "\n",
    "        # draw query in text image, one image row per chunk of `nw` characters\n",
    "        query = \"Q: \" + query\n",
    "        q_text = np.concatenate(\n",
    "            [text_phantom(query[i:i + nw], vw) for i in range(0, len(query), nw)],\n",
    "            axis=0)\n",
    "\n",
    "        # draw bar for GT and prediction; uint8 keeps the stacked frames in\n",
    "        # image dtype instead of promoting every frame to float\n",
    "        bar = make_bar(gt, pred, vlen, vw).astype(np.uint8)\n",
    "\n",
    "        def add_query_and_bar(frame):\n",
    "            \"\"\" Add query text above and GT/prediction bar below the frame\"\"\"\n",
    "            return np.concatenate([q_text, frame, bar], axis=0)\n",
    "\n",
    "        final_clip = cc.fl_image(add_query_and_bar)\n",
    "\n",
    "        if visualize:\n",
    "            # `ipython_display` only returns the HTML object; inside a function it\n",
    "            # must be passed to `display` (a builtin in IPython kernels) to be shown\n",
    "            display(final_clip.ipython_display(maxduration=300))\n",
    "        else:\n",
    "            os.makedirs(f\"visualization/{dt}/{policy}\", exist_ok=True)\n",
    "            save_to = f\"visualization/{dt}/{policy}/{vid}.mp4\"\n",
    "            final_clip.write_videofile(save_to, fps=final_clip.fps)\n",
    "    finally:\n",
    "        source.close()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* ActivityNet Captions Dataset (41.68% at R@0.5)\n",
    "* Charades Dataset (59.17% at R@0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_output(dt):\n",
    "    \"\"\"Load the test dataset, the saved predictions and the video directory for `dt`.\n",
    "\n",
    "    Args:\n",
    "        dt: dataset name, \"anet\" or \"charades\".\n",
    "\n",
    "    Returns:\n",
    "        Tuple (D, preds, vid_dir): the dataset built from the \"test_loader\" section\n",
    "        of the model config, the content of val_prediction.json, and the directory\n",
    "        prefix of the raw videos.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `dt` is not a supported dataset name.\n",
    "    \"\"\"\n",
    "    if dt == \"anet\":\n",
    "        from src.dataset import anet\n",
    "\n",
    "        dataset_cls = anet.ActivityNetCaptionsDataset\n",
    "        # was spelled \"pretraied_models\"; the charades branch reads \"pretrained_models\"\n",
    "        model_dir = \"pretrained_models/anet_LGI\"\n",
    "        vid_dir = \"data/anet/raw_videos/validation/\"\n",
    "    elif dt == \"charades\":\n",
    "        from src.dataset import charades\n",
    "\n",
    "        dataset_cls = charades.CharadesDataset\n",
    "        model_dir = \"pretrained_models/charades_LGI\"\n",
    "        vid_dir = \"data/charades/raw_videos/\"\n",
    "    else:\n",
    "        # used to fall through to an UnboundLocalError at the return statement\n",
    "        raise NotImplementedError(f\"unsupported dataset: {dt!r}\")\n",
    "\n",
    "    config = io_utils.load_yaml(f\"{model_dir}/config.yml\")[\"test_loader\"]\n",
    "    config[\"in_memory\"] = False\n",
    "    D = dataset_cls(config)\n",
    "    preds = io_utils.load_json(f\"{model_dir}/val_prediction.json\")\n",
    "    return D, preds, vid_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load yaml file from pretrained_models/charades_LGI/config.yml\n",
      "Load json file from data/charades/preprocess/query_info/test_info_F1_L10_I3D.json\n",
      "Load json file from pretrained_models/charades_LGI/val_prediction.json\n"
     ]
    }
   ],
   "source": [
    "dt = \"charades\" # among anet|charades\n",
    "D, preds, vid_dir = load_output(dt)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create result videos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "video path: data/charades/raw_videos/SUJWY.mp4\n",
      "speedup factor: 2.0\n",
      "duration  : 32.54\n",
      "vid       : SUJWY\n",
      "Q         : the person begins laughing at something on the screen\n",
      "prediction: [16.866414813995362, 29.932574295997618]\n",
      "gt.       : [17.1, 29.7]\n",
      "Moviepy - Building video __temp__.mp4.\n",
      "MoviePy - Writing audio in __temp__TEMP_MPY_wvf_snd.mp3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                   \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MoviePy - Done.\n",
      "Moviepy - Writing video __temp__.mp4\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                              \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Moviepy - Done !\n",
      "Moviepy - video ready __temp__.mp4\n",
      "video path: data/charades/raw_videos/SJ51G.mp4\n",
      "speedup factor: 2.0\n",
      "duration  : 30.79\n",
      "vid       : SJ51G\n",
      "Q         : person puts the blanket into a closet\n",
      "prediction: [7.205721558332443, 16.812131100296973]\n",
      "gt.       : [16.0, 21.7]\n",
      "Moviepy - Building video __temp__.mp4.\n",
      "MoviePy - Writing audio in __temp__TEMP_MPY_wvf_snd.mp3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                   \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MoviePy - Done.\n",
      "Moviepy - Writing video __temp__.mp4\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                            \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Moviepy - Done !\n",
      "Moviepy - video ready __temp__.mp4\n"
     ]
    }
   ],
   "source": [
    "# number of videos rendered per sampling policy\n",
    "NUM_VIDEOS_PER_POLICY = 1\n",
    "\n",
    "for policy in (\"success\", \"failure\"):\n",
    "    for i in range(NUM_VIDEOS_PER_POLICY):\n",
    "        try:\n",
    "            make_result_video(preds, D, dt, vid_dir, policy, visualize=True)\n",
    "        except Exception as e:\n",
    "            # report the cause for both policies; unlike a bare `except:` this\n",
    "            # does not swallow KeyboardInterrupt\n",
    "            print(e)\n",
    "            print(\"error occured :(\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:view] *",
   "language": "python",
   "name": "conda-env-view-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
